Decoding strategies and parameters¶

We will use the transformers library to generate next-token sequences from a small model, and visualize how decoding hyperparameters change the probability distribution of the next token.

In [1]:
!pip install -q "transformers>=4.45.0" hf_transfer torch accelerate bitsandbytes plotly scipy tqdm

import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

Load the model and tokenizer¶

We will be using a non-instruct model for this exercise.

Non-instruct models, also called base models, are causal language models trained purely for next-token prediction, which makes them suited to open-ended text generation.

They predict the most likely next token based on the previous context, and do not follow commands the way instruction-tuned models do.

Careful prompting and tuning is required to produce usable outputs.

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# We will be using a non-instruct model for this exercise
model_id = "Qwen/Qwen2.5-1.5B"
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id, device_map=device)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)

Example of open-ended text generation¶

In [3]:
from transformers import TextStreamer

prompt = "2 + 2 = "

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(inputs.input_ids, max_new_tokens=256, streamer=TextStreamer(tokenizer))
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
2 + 2 = 
4 \) and \( 2^2 + 2^2 = 8 \), so \( 4 \) and \( 8 \) are not coprime.
- \( 2^3 + 2^3 = 16 \) and \( 2^2 + 2^2 = 8 \), so \( 16 \) and \( 8 \) are not coprime.
- \( 2^4 + 2^4 = 32 \) and \( 2^2 + 2^2 = 8 \), so \( 32 \) and \( 8 \) are not coprime.
- \( 2^5 + 2^5 = 64 \) and \( 2^2 + 2^2 = 8 \), so \( 64 \) and \( 8 \) are not coprime.
- \( 2^6 + 2^6 = 128 \) and \( 2^2 + 2^2 = 8 \), so \( 128 \) and \( 8 \) are not coprime.
- \( 2^7 + 2

Logits and scores¶

Logits are the raw, unnormalized scores for each token in the vocabulary.

Scores are logits on top of which different transformations were applied depending on the decoding strategy.

Final probabilities are obtained by applying the softmax function to the scores.
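
The pipeline above can be sketched with toy numbers (the logit values are made up for illustration; with no decoding transformations applied, the scores equal the logits):

```python
import numpy as np

# Toy logits for a hypothetical 4-token vocabulary.
logits = np.array([5.3, 4.0, 2.9, 0.1])

# With no decoding transformations, scores equal the logits.
scores = logits

# Softmax turns the scores into a probability distribution.
probabilities = np.exp(scores) / np.exp(scores).sum()

print(probabilities)        # the highest logit gets the highest probability
print(probabilities.sum())  # probabilities sum to 1
```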

Retrieving logits and scores¶

Let's use the transformer's generate method to get a more detailed look at the output.

In [4]:
output_dict = model.generate(inputs.input_ids, max_new_tokens=1, output_scores=True, output_logits=True, return_dict_in_generate=True)

for key, value in output_dict.items():
    print(
        key, 
        value.shape if type(value) == torch.Tensor else f'Tuple of {value[0].shape}' if type(value[0]) == torch.Tensor else f'Tuple of tuple of {value[0][0].shape}'
    )
sequences torch.Size([1, 7])
scores Tuple of torch.Size([1, 151936])
logits Tuple of torch.Size([1, 151936])
past_key_values Tuple of tuple of torch.Size([1, 2, 6, 128])

sequences contains the prompt tokens followed by the newly generated token.

In [5]:
output_dict['sequences']
Out[5]:
tensor([[ 17, 488, 220,  17, 284, 220,  19]], device='cuda:0')

logits are the raw, unnormalized scores for each token in the vocabulary.
Their shape is (1, vocab_size)

In [6]:
print(output_dict['logits'][0].shape)
output_dict['logits'][0][:,0:10]
torch.Size([1, 151936])
Out[6]:
tensor([[ 4.0467, -1.9739,  2.8914,  5.2675,  0.1136,  1.6309, -0.9729,  3.1858,
          1.3045,  1.6879]], device='cuda:0')

scores are the logits on top of which different transformations were applied depending on the decoding strategy.
Their shape is (1, vocab_size)

In [7]:
print(output_dict['scores'][0].shape)
output_dict['scores'][0][:,0:10]
torch.Size([1, 151936])
Out[7]:
tensor([[ 4.0467, -1.9739,  2.8914,  5.2675,  0.1136,  1.6309, -0.9729,  3.1858,
          1.3045,  1.6879]], device='cuda:0')

Confirming the size of the vocabulary. Note that it is slightly smaller than the logits dimension (151936): the model's output layer is padded beyond the tokenizer's vocabulary.

In [8]:
len(tokenizer.vocab)
Out[8]:
151665

Visualizing the top-k logit values and probabilities¶

Taking the top-k logit values, we can see the most likely next token in the first position.

In [9]:
output_dict['logits'][0].topk(5)
Out[9]:
torch.return_types.topk(
values=tensor([[22.0941, 21.8434, 21.6329, 21.6252, 21.5669]], device='cuda:0'),
indices=tensor([[19, 16, 21, 15, 17]], device='cuda:0'))

Decoding the top-k indexes to tokens, we can see what the next token looks like.
Observe that there are also strange-looking results, such as 'Ġ', 'Ġ\\', and 'âij'. These tokens contain special characters from the tokenizer's byte-level encoding that need to be decoded.

In [10]:
top_k = 15
# get the top-k values and indexes
top_k_values = output_dict['logits'][0].topk(top_k).values
top_k_indexes = output_dict['logits'][0].topk(top_k).indices
# go from a tensor of shape (1, top_k) to a tensor of shape (top_k, 1)
top_k_indexes = top_k_indexes.reshape(-1, 1)

tokenizer.convert_ids_to_tokens(top_k_indexes)
Out[10]:
['4',
 '1',
 '6',
 '0',
 '2',
 '3',
 '8',
 '5',
 '7',
 '9',
 'Ġ',
 'Ġ\\',
 'âij',
 'Ġ-',
 'Ġ(']

We can convert each token to a string and visualize the true next token result.

In [11]:
start_tokens = tokenizer.convert_ids_to_tokens(top_k_indexes)
for token in start_tokens:
    print(f'{token} => "{tokenizer.convert_tokens_to_string([token])}"')
4 => "4"
1 => "1"
6 => "6"
0 => "0"
2 => "2"
3 => "3"
8 => "8"
5 => "5"
7 => "7"
9 => "9"
Ġ => " "
Ġ\ => " \"
âij => "�"
Ġ- => " -"
Ġ( => " ("

The strange tokens were decoded to:

Ġ => " "
Ġ\ => " \"
âij => "�"
Ġ- => " -"
Ġ( => " ("

But âij => "�" still appears unreadable. This token requires other tokens in its sequence to properly decode.
Since we decoded each token separately, decoding failed in this case. Most likely it is part of a multi-byte, non-Latin character.


Now, let's visualize the top-k logit values.

In [12]:
import plotly.graph_objects as go
import numpy as np

# Convert tensors to numpy arrays and move to CPU 
top_k_logits_values_np = top_k_values.cpu().numpy().flatten()
top_k_logits_indexes_np = top_k_indexes.cpu().numpy().flatten()

# Convert token indices to actual tokens
start_tokens = tokenizer.convert_ids_to_tokens(top_k_logits_indexes_np)
strings = [tokenizer.convert_tokens_to_string([token]) for token in start_tokens]

# Create the bar chart
fig = go.Figure(data=[
    go.Bar(
        x=[f'"{string}" `{token}`' for token, string in zip(start_tokens, strings)],
        y=np.round(top_k_logits_values_np, 2),
        text=np.round(top_k_logits_values_np, 2),  
        textposition='auto'
    )
])

# Update layout for logarithmic y-axis and other customizations
fig.update_layout(
    title=f'Top {top_k} Token Logits',
    xaxis_title='Tokens',
    yaxis_title='Logit value (log scale)',
    yaxis_type='log',  
)

# Show the plot
fig.show()

And top-k score probabilities.

In [13]:
from scipy.special import softmax

# Create the bar chart
fig = go.Figure(data=[
    go.Bar(
        x=[f'"{string}" `{token}`' for token, string in zip(start_tokens, strings)],
        y=np.round(softmax(top_k_logits_values_np), 2),
        text=np.round(softmax(top_k_logits_values_np), 2),  
        textposition='auto',
        marker_color='red'
    )
])

# Update layout and other customizations
fig.update_layout(
    title=f'Top {top_k} Token Probabilities',
    xaxis_title='Tokens',
    yaxis_title='Probability',
)

# Show the plot
fig.show()

We'll create a helper function to visualize the top-k logit values and probabilities, for later use.

In [14]:
from scipy.special import softmax

def visualize_top_k_tokens(logits, scores = [], k=5):
    tok_k = logits.topk(k)
    top_k_values = tok_k.values
    top_k_indexes = tok_k.indices

    # Convert tensors to numpy arrays and move to CPU 
    top_k_logits_values_np = top_k_values.cpu().numpy().flatten()
    top_k_logits_indexes_np = top_k_indexes.cpu().numpy().flatten()

    if torch.is_tensor(scores):
        scores = scores.cpu().numpy().flatten()
    top_k_scores_values_np = softmax(scores)[top_k_logits_indexes_np]

    # Convert token indices to actual tokens
    tokens = tokenizer.convert_ids_to_tokens(top_k_logits_indexes_np)
    strings = [tokenizer.convert_tokens_to_string([token]) for token in tokens]

    fig = go.Figure()

    # Create x-axis labels
    x_labels = [f'"{string}" `{token}`' for token, string in zip(tokens, strings)]

    # Add Logits bar
    fig.add_trace(go.Bar(
        name='Logits',
        x=[x - 0.2 for x in range(len(x_labels))],  # Shift left
        y=np.round(top_k_logits_values_np, 2),
        text=np.round(top_k_logits_values_np, 2),
        textposition='auto',
        marker_color='blue',
        width=0.4,  # Reduce bar width
        yaxis='y'
    ))

    # Add Scores bar
    fig.add_trace(go.Bar(
        name='Softmax(Scores)',
        x=[x + 0.2 for x in range(len(x_labels))],  # Shift right
        y=np.round(top_k_scores_values_np, 2),
        text=np.round(top_k_scores_values_np, 2),
        textposition='auto',
        marker_color='red',
        width=0.4,  # Reduce bar width
        yaxis='y2'
    ))

    # Update layout for independent Y axes and other customizations
    fig.update_layout(
        title=f'Top {k} Token Probabilities',
        xaxis_title='Tokens',
        yaxis=dict(title='Logits', side='left'),
        yaxis2=dict(title='Softmax(Scores)', side='right', overlaying='y'),
        xaxis=dict(
            tickangle=45,
            tickmode='array',
            tickvals=list(range(len(tokens))),
            ticktext=strings
        ),
        legend=dict(x=1.1, y=1),
        barmode='group'  # Group bars side by side
    )

    # Show the plot
    fig.show()

visualize_top_k_tokens(output_dict['logits'][0], output_dict['scores'][0], 50)

Let's also create a helper function to generate logits and scores for ease of use.

In [15]:
def generate_logits(prompt, **extra_generation_config):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    output_dict = model.generate(inputs.input_ids, max_new_tokens=1, output_scores=True, output_logits=True, return_dict_in_generate=True, **extra_generation_config)

    return {'logits': output_dict['logits'][0], 'scores': output_dict['scores'][0]}

And let's test them both!

In [16]:
logits = generate_logits("The sky is")
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

Greedy search decoding strategy¶

Greedy search is the simplest decoding method. It selects the word with the highest probability as its next word.

image.png

Starting from the word "The", the algorithm greedily chooses the next word of highest probability "nice" and so on, so that the final generated word sequence is ("The", "nice", "woman") having an overall probability of 0.5×0.4=0.2.
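
The walk above can be reproduced with a toy probability table (the numbers are made up to match the diagram; greedy search is just a repeated argmax over the next-token distribution):

```python
# Toy next-word distributions mirroring the diagram (hypothetical numbers).
next_probs = {
    ("The",):        {"nice": 0.5, "dog": 0.4, "car": 0.1},
    ("The", "nice"): {"woman": 0.4, "house": 0.3, "guy": 0.3},
}

sequence = ["The"]
score = 1.0
for _ in range(2):
    dist = next_probs[tuple(sequence)]
    token = max(dist, key=dist.get)  # greedy: always pick the most likely token
    score *= dist[token]
    sequence.append(token)

print(sequence, score)  # ['The', 'nice', 'woman'] 0.2
```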

In [59]:
prompt = "My name is"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(inputs.input_ids, max_new_tokens=32, streamer=TextStreamer(tokenizer))
My name is John. I am a student. I am in Class 3, Grade 7. I have a good friend. His name is Mike. He is a

Sample decoding strategies¶

These strategies are applied over the logits to compute the scores.
They are simple because they only need the current logits, with no look-ahead or search.

Temperature¶

The temperature parameter is used to modulate the next token probabilities.

scores = logits / temperature

It typically ranges between 0 and 2: values below 1 sharpen the distribution (more deterministic), while values above 1 flatten it (more random). Most APIs expose it in this range.
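
A minimal numpy sketch (toy logits) of how dividing by the temperature reshapes the softmax distribution:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([4.0, 3.0, 1.0])

# scores = logits / temperature, then softmax
for temperature in (0.1, 1.0):
    print(temperature, softmax(logits / temperature).round(3))
```

A low temperature concentrates almost all probability mass on the top token; a temperature of 1 leaves the distribution unchanged.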

In [60]:
# prompt = '2 + 2 = '
prompt = 'My name is'

logits = generate_logits(prompt, do_sample=True, temperature=0.01)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

logits = generate_logits(prompt, do_sample=True, temperature=1)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

We will continue with a temperature of 1 to easily showcase the effects of the other parameters.

Top-k¶

The top-k parameter is used to limit the next token choices to the top-k most likely tokens.
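
A simplified sketch of the idea (toy logits, not the library's actual implementation): logits outside the top k are masked to -inf, so their probability after softmax becomes zero.

```python
import numpy as np

def top_k_scores(logits, k):
    # Keep the k largest logits; mask the rest to -inf so softmax gives them 0.
    scores = np.full_like(logits, -np.inf)
    keep = np.argsort(logits)[-k:]
    scores[keep] = logits[keep]
    return scores

logits = np.array([4.0, 3.0, 1.0, 0.5, -2.0])
scores = top_k_scores(logits, k=2)
probs = np.exp(scores - scores.max())
probs = probs / probs.sum()
print(probs.round(3))  # only the two largest logits get non-zero probability
```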

In [18]:
prompt = "My name is"

logits = generate_logits(prompt, do_sample=True, temperature=1)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

logits = generate_logits(prompt, do_sample=True, temperature=1, top_k=10)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

Top-p¶

The top-p parameter is used to limit the next token choices to the smallest set of most probable tokens with probabilities that add up to top_p or higher.
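
The selection rule can be sketched as follows (toy probabilities; `top_p_indices` is a hypothetical helper, not a transformers API):

```python
import numpy as np

def top_p_indices(probs, top_p):
    # Sort descending and keep the smallest prefix whose mass reaches top_p.
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, top_p) + 1
    return order[:cutoff]

probs = np.array([0.5, 0.25, 0.15, 0.07, 0.03])
print(top_p_indices(probs, 0.7))  # [0 1]: 0.5 + 0.25 >= 0.7
```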

In [19]:
prompt = "My name is"

logits = generate_logits(prompt, do_sample=True, temperature=1)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

logits = generate_logits("My name is", do_sample=True, temperature=1, top_p=0.7)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

Min-P¶

Min-p sets a minimum token probability, which is scaled by the probability of the most likely token. It must be a value between 0 and 1.
Typical values are in the 0.01-0.2 range, comparably selective to setting top_p in the 0.99-0.8 range (note the inverse relationship to usual top_p values).
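
The scaling rule can be sketched with toy numbers (a simplified illustration, not the library's actual code):

```python
import numpy as np

probs = np.array([0.40, 0.30, 0.15, 0.10, 0.05])
min_p = 0.5

# The threshold scales with the most likely token's probability.
threshold = min_p * probs.max()  # 0.5 * 0.40 = 0.20
keep = np.where(probs >= threshold)[0]
print(keep)  # tokens 0 and 1 survive (0.40 and 0.30 >= 0.20)
```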

In [20]:
prompt = "My name is"

logits = generate_logits(prompt, do_sample=True, temperature=1)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

logits = generate_logits(prompt, do_sample=True, temperature=1, min_p=0.5)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

Typical P¶

Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to typical_p or higher are kept for generation. See this paper for more details.
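
Following that description, a rough numpy sketch (toy probabilities): tokens are ranked by how close their surprisal (-log p) is to the distribution's entropy, and the smallest such set whose mass reaches typical_p is kept. This mirrors the idea, not the exact library code.

```python
import numpy as np

def typical_indices(probs, typical_p):
    # Entropy of the full distribution: the "expected" surprisal.
    surprisal = -np.log(probs)
    entropy = (probs * surprisal).sum()
    # Rank tokens by how close their surprisal is to the entropy.
    order = np.argsort(np.abs(surprisal - entropy))
    # Keep the smallest typical set whose mass reaches typical_p.
    cumulative = np.cumsum(probs[order])
    cutoff = (cumulative < typical_p).sum() + 1
    return order[:cutoff]

probs = np.array([0.45, 0.35, 0.15, 0.05])
print(typical_indices(probs, 0.5))  # the most *typical* tokens, not simply the most probable
```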

In [21]:
prompt = "My name is"

logits = generate_logits(prompt, do_sample=True, temperature=1)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

logits = generate_logits(prompt, do_sample=True, temperature=1, typical_p=0.5)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

Repetition Penalty¶

The repetition penalty parameter penalizes tokens that have already appeared in the context.

For every such token, a positive logit is divided by repetition_penalty and a negative logit is multiplied by it, so that for repetition_penalty > 1 repeated tokens always become less likely.

See this paper for more details.
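
A simplified sketch of that rule (toy logits; `apply_repetition_penalty` is an illustrative helper, not the transformers implementation):

```python
import numpy as np

def apply_repetition_penalty(logits, seen_token_ids, penalty):
    scores = logits.copy()
    for t in set(seen_token_ids):
        # Positive logits are divided, negative ones multiplied,
        # so the penalty always pushes a repeated token down.
        if scores[t] > 0:
            scores[t] /= penalty
        else:
            scores[t] *= penalty
    return scores

logits = np.array([2.0, 1.0, -0.5, 3.0])
print(apply_repetition_penalty(logits, seen_token_ids=[0, 2], penalty=1.5))
# token 0: 2.0 / 1.5 ~ 1.33, token 2: -0.5 * 1.5 = -0.75
```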

In [22]:
# [Chorus: Eminem]
# Hi, my name is, what? My name is, who?
# My name is, chka-chka, Slim Shady

# prompt = "Hi, my name is, what? My name is, who?" # 

prompt = "4 + 4 - 4 = "

logits = generate_logits(prompt, do_sample=True, temperature=1)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

logits = generate_logits(prompt, do_sample=True, temperature=1, repetition_penalty=1.5)
visualize_top_k_tokens(logits['logits'], logits['scores'], 50)

TODO: add multi-beam search and decode, and visualize them¶

num_beams (`int`, *optional*, defaults to 1):
    Number of beams for beam search. 1 means no beam search.
num_beam_groups (`int`, *optional*, defaults to 1):
    Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
    See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.

Beam search decoding strategy¶

image.png

In [111]:
from dataclasses import dataclass, field

@dataclass
class BeamNode:
    my_tokens: torch.Tensor
    sequence_tokens: torch.Tensor

    my_score: float
    sequence_score: float
    children: list['BeamNode'] = field(default_factory=list)
    chosen: bool = False

    def add_child(self, child: 'BeamNode'):
        self.children.append(child)

    def text(self):
        return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(self.sequence_tokens))
    
    def my_text(self):
        return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(self.my_tokens))

    def __str__(self):
        return f'{self.sequence_score:.10f}: {self.my_text()} ({self.text()})'
    
    def id(self):
        return '_'+'_'.join([str(token) for token in self.sequence_tokens.tolist()])
In [283]:
from tqdm import tqdm

def forward_modal_and_sample(model, tokens, top_k=2, temperature=None):
    model_out = model(input_ids=tokens.unsqueeze(0), attention_mask=torch.ones_like(tokens).unsqueeze(0).to(model.device))

    if temperature is not None and temperature > 0 and temperature <= 1:
        adjusted = model_out['logits'][:, -1, :] / temperature
        probabilities = adjusted.softmax(dim=-1)
        # Sample from probabilities using multinomial to get indices
        sampled_indices = torch.multinomial(probabilities[0], num_samples=top_k)
        values = probabilities[0][sampled_indices]
        return values, sampled_indices


    probabilities = model_out['logits'][:, -1, :].softmax(dim=-1)
    values, indices = probabilities.topk(top_k)
    return values[0], indices[0]

def token_to_string(tokens):
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(tokens))

def beam_search(model, start_text, nb_beams=3, iterations=10, temperature=0):
    # encode text to tokens
    start_tokens = tokenizer(start_text, return_tensors="pt").to(model.device)['input_ids'][0]

    # create root beam from which all beams will stem
    root = BeamNode(start_tokens, start_tokens, my_score=1, sequence_score=1, chosen=True)

    # initialize a list of beams processed by the last iteration
    last_processed_beams = [root]

    for iteration in tqdm(list(range(iterations))):

        # initialize a list of next beams to process
        next_beams_to_process = []
        
        # for each beam in the last iteration
        for beam in last_processed_beams:
            
            # generate a list of top-k tokens and their scores
            top_k_values, top_k_indices = forward_modal_and_sample(model, beam.sequence_tokens, nb_beams, temperature=temperature)

            # for each generated token
            for token, token_score in zip(top_k_indices, top_k_values):

                # create a new beam with the current token
                next_beam = BeamNode(
                    my_tokens=token.unsqueeze(0), 
                    sequence_tokens=torch.cat([beam.sequence_tokens, token.unsqueeze(0)]), 
                    sequence_score=beam.sequence_score * token_score,
                    my_score=token_score,
                    chosen=False
                )

                # add the new beam to the list of next beams to process
                next_beams_to_process.append(next_beam)

                # add the new beam as a child to the current beam processing
                beam.add_child(next_beam)

        # sort the iteration of beams by their sequence score
        next_beams_to_process = sorted(next_beams_to_process, key=lambda x: x.sequence_score, reverse=True)

        # keep top nb_beams
        last_processed_beams = next_beams_to_process[:nb_beams]

        # mark the chosen beam
        for beam in last_processed_beams:
            beam.chosen = True

    # return the found beams, and the root beam used for visualization
    return last_processed_beams, root
In [294]:
# set torch seed
torch.manual_seed(0)

beams, root_beam = beam_search(model, "My name is", nb_beams=5, iterations=5)
for beam in beams:
    print(beam.text())
    print('---')
100%|██████████| 5/5 [00:03<00:00,  1.36it/s]
My name is John Smith. I am
---
My name is John. I am a
---
My name is John. I am 
---
My name is Tom. I am a
---
My name is Jack. I am a
---

In [295]:
from IPython.display import Markdown, display

def show_beam_search(root_beam: BeamNode, out_beams: list[BeamNode]):
    """
    Visualize the beam search process using mermaid.js syntax.
    """

    # initialize the mermaid text
    mermaid_text = """
flowchart TB
    %% Configure direction and layout
    direction TB
    classDef discarded fill:#5e1300
    classDef output fill:#004713
    """

    # process each beam and add its children to the mermaid text
    def process_beam(beam: BeamNode, iteration: int):
        nonlocal mermaid_text
        
        # if beam.children:
        #     mermaid_text += f'    subgraph level_{iteration}[Level {iteration}]\n'

        for child in beam.children:
            clean_beam_text = beam.my_text().replace('"', '').replace('[', '').replace(']', '').replace('\n', '\\n').replace('(', '\\(').replace(')', '\\)') or " "
            clean_child_text = child.my_text().replace('"', '').replace('[', '').replace(']', '').replace('\n', '\\n').replace('(', '\\(').replace(')', '\\)') or " "
            mermaid_text += f'    {beam.id()}["{clean_beam_text}"] -->|Beam: {child.sequence_score:.10f}\nToken: {child.my_score:.10f}| {child.id()}["{clean_child_text}"]{":::discarded" if not child.chosen else ""}'
            mermaid_text += '\n'

            process_beam(child, iteration + 1)

        # if beam.children:
        #     mermaid_text += f'    end\n'

    # process the root beam
    process_beam(root_beam, 0)

    # add the leaf beams to the mermaid text for easy visualization
    for out_beam_idx, beam in enumerate(out_beams):
        mermaid_text += f'    {beam.id()}["{beam.my_text()}"] -->|Output beam {out_beam_idx}| _out_{out_beam_idx}["{beam.text()}"]:::output'
        mermaid_text += '\n'

    # show markdown
    display(Markdown(f'''
```mermaid
{mermaid_text}
```
    '''))

    return mermaid_text

print(show_beam_search(root_beam, beams))
flowchart TB
    %% Configure direction and layout
    direction TB
    classDef discarded fill:#5e1300
    classDef output fill:#004713
        _5050_829_374["My name is"] -->|Beam: 0.0411579572
Token: 0.0411579572| _5050_829_374_3757[" John"]
    _5050_829_374_3757[" John"] -->|Beam: 0.0094370460
Token: 0.2292885035| _5050_829_374_3757_13["."]
    _5050_829_374_3757_13["."] -->|Beam: 0.0051076598
Token: 0.5412350297| _5050_829_374_3757_13_358[" I"]
    _5050_829_374_3757_13_358[" I"] -->|Beam: 0.0015957108
Token: 0.3124152422| _5050_829_374_3757_13_358_1079[" am"]
    _5050_829_374_3757_13_358_1079[" am"] -->|Beam: 0.0004616437
Token: 0.2893028557| _5050_829_374_3757_13_358_1079_264[" a"]
    _5050_829_374_3757_13_358_1079[" am"] -->|Beam: 0.0003415810
Token: 0.2140619755| _5050_829_374_3757_13_358_1079_220[" "]
    _5050_829_374_3757_13_358_1079[" am"] -->|Beam: 0.0001389844
Token: 0.0870987326| _5050_829_374_3757_13_358_1079_458[" an"]:::discarded
    _5050_829_374_3757_13_358_1079[" am"] -->|Beam: 0.0001263757
Token: 0.0791971236| _5050_829_374_3757_13_358_1079_504[" from"]:::discarded
    _5050_829_374_3757_13_358_1079[" am"] -->|Beam: 0.0000478095
Token: 0.0299612693| _5050_829_374_3757_13_358_1079_304[" in"]:::discarded
    _5050_829_374_3757_13_358[" I"] -->|Beam: 0.0009074934
Token: 0.1776730269| _5050_829_374_3757_13_358_2776["'m"]
    _5050_829_374_3757_13_358_2776["'m"] -->|Beam: 0.0002342992
Token: 0.2581827939| _5050_829_374_3757_13_358_2776_264[" a"]:::discarded
    _5050_829_374_3757_13_358_2776["'m"] -->|Beam: 0.0001624699
Token: 0.1790315062| _5050_829_374_3757_13_358_2776_220[" "]:::discarded
    _5050_829_374_3757_13_358_2776["'m"] -->|Beam: 0.0001033742
Token: 0.1139118373| _5050_829_374_3757_13_358_2776_504[" from"]:::discarded
    _5050_829_374_3757_13_358_2776["'m"] -->|Beam: 0.0000912654
Token: 0.1005686522| _5050_829_374_3757_13_358_2776_458[" an"]:::discarded
    _5050_829_374_3757_13_358_2776["'m"] -->|Beam: 0.0000343011
Token: 0.0377975740| _5050_829_374_3757_13_358_2776_304[" in"]:::discarded
    _5050_829_374_3757_13_358[" I"] -->|Beam: 0.0005643757
Token: 0.1104959473| _5050_829_374_3757_13_358_614[" have"]:::discarded
    _5050_829_374_3757_13_358[" I"] -->|Beam: 0.0002997114
Token: 0.0586788096| _5050_829_374_3757_13_358_4249["’m"]:::discarded
    _5050_829_374_3757_13_358[" I"] -->|Beam: 0.0002174854
Token: 0.0425802469| _5050_829_374_3757_13_358_3887[" live"]:::discarded
    _5050_829_374_3757_13["."] -->|Beam: 0.0007594973
Token: 0.0804804042| _5050_829_374_3757_13_3017[" My"]:::discarded
    _5050_829_374_3757_13["."] -->|Beam: 0.0003413731
Token: 0.0361737236| _5050_829_374_3757_13_1096[" This"]:::discarded
    _5050_829_374_3757_13["."] -->|Beam: 0.0001402714
Token: 0.0148639129| _5050_829_374_3757_13_11201[" Today"]:::discarded
    _5050_829_374_3757_13["."] -->|Beam: 0.0001202185
Token: 0.0127389990| _5050_829_374_3757_13_1084[" It"]:::discarded
    _5050_829_374_3757[" John"] -->|Beam: 0.0080531100
Token: 0.1956635118| _5050_829_374_3757_9082[" Smith"]
    _5050_829_374_3757_9082[" Smith"] -->|Beam: 0.0037134732
Token: 0.4611228704| _5050_829_374_3757_9082_13["."]
    _5050_829_374_3757_9082_13["."] -->|Beam: 0.0015315652
Token: 0.4124346972| _5050_829_374_3757_9082_13_358[" I"]
    _5050_829_374_3757_9082_13_358[" I"] -->|Beam: 0.0005084904
Token: 0.3320069909| _5050_829_374_3757_9082_13_358_1079[" am"]
    _5050_829_374_3757_9082_13_358[" I"] -->|Beam: 0.0002385156
Token: 0.1557332426| _5050_829_374_3757_9082_13_358_2776["'m"]:::discarded
    _5050_829_374_3757_9082_13_358[" I"] -->|Beam: 0.0001708670
Token: 0.1115636304| _5050_829_374_3757_9082_13_358_614[" have"]:::discarded
    _5050_829_374_3757_9082_13_358[" I"] -->|Beam: 0.0000993508
Token: 0.0648687780| _5050_829_374_3757_9082_13_358_4249["’m"]:::discarded
    _5050_829_374_3757_9082_13_358[" I"] -->|Beam: 0.0000612552
Token: 0.0399951488| _5050_829_374_3757_9082_13_358_572[" was"]:::discarded
    _5050_829_374_3757_9082_13["."] -->|Beam: 0.0003693283
Token: 0.0994563103| _5050_829_374_3757_9082_13_3017[" My"]:::discarded
    _5050_829_374_3757_9082_13["."] -->|Beam: 0.0001343704
Token: 0.0361845754| _5050_829_374_3757_9082_13_1096[" This"]:::discarded
    _5050_829_374_3757_9082_13["."] -->|Beam: 0.0000975317
Token: 0.0262642782| _5050_829_374_3757_9082_13_5209[" Please"]:::discarded
    _5050_829_374_3757_9082_13["."] -->|Beam: 0.0000638637
Token: 0.0171978455| _5050_829_374_3757_9082_13_576[" The"]:::discarded
    _5050_829_374_3757_9082[" Smith"] -->|Beam: 0.0013892787
Token: 0.1725145578| _5050_829_374_3757_9082_323[" and"]:::discarded
    _5050_829_374_3757_9082[" Smith"] -->|Beam: 0.0006775868
Token: 0.0841397718| _5050_829_374_3757_9082_11[","]:::discarded
    _5050_829_374_3757_9082[" Smith"] -->|Beam: 0.0006074440
Token: 0.0754297376| _5050_829_374_3757_9082_58883["."]:::discarded
    _5050_829_374_3757_9082[" Smith"] -->|Beam: 0.0003360215
Token: 0.0417256765| _5050_829_374_3757_9082_3837[","]:::discarded
    _5050_829_374_3757[" John"] -->|Beam: 0.0041128518
Token: 0.0999284759| _5050_829_374_3757_323[" and"]:::discarded
    _5050_829_374_3757[" John"] -->|Beam: 0.0028648935
Token: 0.0696072802| _5050_829_374_3757_58883["."]:::discarded
    _5050_829_374_3757[" John"] -->|Beam: 0.0017616025
Token: 0.0428010188| _5050_829_374_3757_11[","]:::discarded
    _5050_829_374["My name is"] -->|Beam: 0.0233974829
Token: 0.0233974829| _5050_829_374_8364[" Tom"]
    _5050_829_374_8364[" Tom"] -->|Beam: 0.0065385150
Token: 0.2794537842| _5050_829_374_8364_13["."]
    _5050_829_374_8364_13["."] -->|Beam: 0.0037125414
Token: 0.5677958131| _5050_829_374_8364_13_358[" I"]
    _5050_829_374_8364_13_358[" I"] -->|Beam: 0.0010795386
Token: 0.2907815576| _5050_829_374_8364_13_358_1079[" am"]
    _5050_829_374_8364_13_358_1079[" am"] -->|Beam: 0.0003097244
Token: 0.2869044542| _5050_829_374_8364_13_358_1079_264[" a"]
    _5050_829_374_8364_13_358_1079[" am"] -->|Beam: 0.0001624784
Token: 0.1505072862| _5050_829_374_8364_13_358_1079_220[" "]:::discarded
    _5050_829_374_8364_13_358_1079[" am"] -->|Beam: 0.0001167530
Token: 0.1081508547| _5050_829_374_8364_13_358_1079_458[" an"]:::discarded
    _5050_829_374_8364_13_358_1079[" am"] -->|Beam: 0.0000959194
Token: 0.0888522416| _5050_829_374_8364_13_358_1079_504[" from"]:::discarded
    _5050_829_374_8364_13_358_1079[" am"] -->|Beam: 0.0000505730
Token: 0.0468468554| _5050_829_374_8364_13_358_1079_304[" in"]:::discarded
    _5050_829_374_8364_13_358[" I"] -->|Beam: 0.0008384187
Token: 0.2258341610| _5050_829_374_8364_13_358_2776["'m"]:::discarded
    _5050_829_374_8364_13_358[" I"] -->|Beam: 0.0004936044
Token: 0.1329559237| _5050_829_374_8364_13_358_614[" have"]:::discarded
    _5050_829_374_8364_13_358[" I"] -->|Beam: 0.0001844683
Token: 0.0496878736| _5050_829_374_8364_13_358_3887[" live"]:::discarded
    _5050_829_374_8364_13_358[" I"] -->|Beam: 0.0001795371
Token: 0.0483596213| _5050_829_374_8364_13_358_4249["’m"]:::discarded
    _5050_829_374_8364_13["."] -->|Beam: 0.0003877403
Token: 0.0593009703| _5050_829_374_8364_13_3017[" My"]:::discarded
    _5050_829_374_8364_13["."] -->|Beam: 0.0003446365
Token: 0.0527086817| _5050_829_374_8364_13_1096[" This"]:::discarded
    _5050_829_374_8364_13["."] -->|Beam: 0.0001503224
Token: 0.0229902919| _5050_829_374_8364_13_9909["("]:::discarded
    _5050_829_374_8364_13["."] -->|Beam: 0.0001027653
Token: 0.0157169104| _5050_829_374_8364_13_42344[" ("]:::discarded
    _5050_829_374_8364[" Tom"] -->|Beam: 0.0036408773
Token: 0.1556097865| _5050_829_374_8364_58883["."]:::discarded
    _5050_829_374_8364[" Tom"] -->|Beam: 0.0016900066
Token: 0.0722302720| _5050_829_374_8364_323[" and"]:::discarded
    _5050_829_374_8364[" Tom"] -->|Beam: 0.0014839880
Token: 0.0634251162| _5050_829_374_8364_9082[" Smith"]:::discarded
    _5050_829_374_8364[" Tom"] -->|Beam: 0.0010783494
Token: 0.0460882671| _5050_829_374_8364_3837[","]:::discarded
    _5050_829_374["My name is"] -->|Beam: 0.0169425271
Token: 0.0169425271| _5050_829_374_10244[" Mary"]
    _5050_829_374_10244[" Mary"] -->|Beam: 0.0043946323
Token: 0.2593846917| _5050_829_374_10244_13["."]
    _5050_829_374_10244_13["."] -->|Beam: 0.0027235150
Token: 0.6197367311| _5050_829_374_10244_13_358[" I"]
    _5050_829_374_10244_13_358[" I"] -->|Beam: 0.0008050094
Token: 0.2955773771| _5050_829_374_10244_13_358_1079[" am"]:::discarded
    _5050_829_374_10244_13_358[" I"] -->|Beam: 0.0005264900
Token: 0.1933126897| _5050_829_374_10244_13_358_2776["'m"]:::discarded
    _5050_829_374_10244_13_358[" I"] -->|Beam: 0.0003567809
Token: 0.1310001612| _5050_829_374_10244_13_358_614[" have"]:::discarded
    _5050_829_374_10244_13_358[" I"] -->|Beam: 0.0001402669
Token: 0.0515021682| _5050_829_374_10244_13_358_3887[" live"]:::discarded
    _5050_829_374_10244_13_358[" I"] -->|Beam: 0.0001339833
Token: 0.0491950065| _5050_829_374_10244_13_358_4249["’m"]:::discarded
    _5050_829_374_10244_13["."] -->|Beam: 0.0002790856
Token: 0.0635060146| _5050_829_374_10244_13_3017[" My"]:::discarded
    _5050_829_374_10244_13["."] -->|Beam: 0.0001797947
Token: 0.0409123376| _5050_829_374_10244_13_1096[" This"]:::discarded
    _5050_829_374_10244_13["."] -->|Beam: 0.0000633895
Token: 0.0144243091| _5050_829_374_10244_13_5692[" Here"]:::discarded
    _5050_829_374_10244_13["."] -->|Beam: 0.0000536795
Token: 0.0122147910| _5050_829_374_10244_13_9909["("]:::discarded
    _5050_829_374_10244[" Mary"] -->|Beam: 0.0020185323
Token: 0.1191399693| _5050_829_374_10244_9082[" Smith"]:::discarded
    _5050_829_374_10244[" Mary"] -->|Beam: 0.0014367392
Token: 0.0848007649| _5050_829_374_10244_323[" and"]:::discarded
    _5050_829_374_10244[" Mary"] -->|Beam: 0.0011306301
Token: 0.0667332634| _5050_829_374_10244_58883["."]:::discarded
    _5050_829_374_10244[" Mary"] -->|Beam: 0.0006669579
Token: 0.0393659025| _5050_829_374_10244_11[","]:::discarded
    _5050_829_374["My name is"] -->|Beam: 0.0161256213
Token: 0.0161256213| _5050_829_374_6798[" David"]
    _5050_829_374_6798[" David"] -->|Beam: 0.0032928335
Token: 0.2041988522| _5050_829_374_6798_13["."]:::discarded
    _5050_829_374_6798[" David"] -->|Beam: 0.0016449407
Token: 0.1020078957| _5050_829_374_6798_323[" and"]:::discarded
    _5050_829_374_6798[" David"] -->|Beam: 0.0010011267
Token: 0.0620829836| _5050_829_374_6798_11[","]:::discarded
    _5050_829_374_6798[" David"] -->|Beam: 0.0006185918
Token: 0.0383608043| _5050_829_374_6798_9082[" Smith"]:::discarded
    _5050_829_374_6798[" David"] -->|Beam: 0.0004227330
Token: 0.0262149870| _5050_829_374_6798_58883["."]:::discarded
    _5050_829_374["My name is"] -->|Beam: 0.0152858030
Token: 0.0152858030| _5050_829_374_7607[" Jack"]
    _5050_829_374_7607[" Jack"] -->|Beam: 0.0053940336
Token: 0.3528786600| _5050_829_374_7607_13["."]
    _5050_829_374_7607_13["."] -->|Beam: 0.0031685214
Token: 0.5874122381| _5050_829_374_7607_13_358[" I"]
    _5050_829_374_7607_13_358[" I"] -->|Beam: 0.0009475948
Token: 0.2990652919| _5050_829_374_7607_13_358_1079[" am"]
    _5050_829_374_7607_13_358_1079[" am"] -->|Beam: 0.0002867103
Token: 0.3025663197| _5050_829_374_7607_13_358_1079_264[" a"]
    _5050_829_374_7607_13_358_1079[" am"] -->|Beam: 0.0001532006
Token: 0.1616730839| _5050_829_374_7607_13_358_1079_220[" "]:::discarded
    _5050_829_374_7607_13_358_1079[" am"] -->|Beam: 0.0000864749
Token: 0.0912572592| _5050_829_374_7607_13_358_1079_504[" from"]:::discarded
    _5050_829_374_7607_13_358_1079[" am"] -->|Beam: 0.0000846806
Token: 0.0893637389| _5050_829_374_7607_13_358_1079_458[" an"]:::discarded
    _5050_829_374_7607_13_358_1079[" am"] -->|Beam: 0.0000356248
Token: 0.0375949740| _5050_829_374_7607_13_358_1079_29235[" twelve"]:::discarded
    _5050_829_374_7607_13_358[" I"] -->|Beam: 0.0007536290
Token: 0.2378487885| _5050_829_374_7607_13_358_2776["'m"]:::discarded
    _5050_829_374_7607_13_358[" I"] -->|Beam: 0.0003680342
Token: 0.1161532849| _5050_829_374_7607_13_358_614[" have"]:::discarded
    _5050_829_374_7607_13_358[" I"] -->|Beam: 0.0001667282
Token: 0.0526201837| _5050_829_374_7607_13_358_3887[" live"]:::discarded
    _5050_829_374_7607_13_358[" I"] -->|Beam: 0.0001425381
Token: 0.0449856855| _5050_829_374_7607_13_358_4249["’m"]:::discarded
    _5050_829_374_7607_13["."] -->|Beam: 0.0003242713
Token: 0.0601166561| _5050_829_374_7607_13_3017[" My"]:::discarded
    _5050_829_374_7607_13["."] -->|Beam: 0.0002582566
Token: 0.0478781909| _5050_829_374_7607_13_1096[" This"]:::discarded
    _5050_829_374_7607_13["."] -->|Beam: 0.0001329789
Token: 0.0246529635| _5050_829_374_7607_13_7607[" Jack"]:::discarded
    _5050_829_374_7607_13["."] -->|Beam: 0.0001045688
Token: 0.0193860177| _5050_829_374_7607_13_11201[" Today"]:::discarded
    _5050_829_374_7607[" Jack"] -->|Beam: 0.0015295137
Token: 0.1000610664| _5050_829_374_7607_323[" and"]:::discarded
    _5050_829_374_7607[" Jack"] -->|Beam: 0.0008357323
Token: 0.0546737611| _5050_829_374_7607_58883["."]:::discarded
    _5050_829_374_7607[" Jack"] -->|Beam: 0.0007785768
Token: 0.0509346314| _5050_829_374_7607_9082[" Smith"]:::discarded
    _5050_829_374_7607[" Jack"] -->|Beam: 0.0007304933
Token: 0.0477890074| _5050_829_374_7607_11[","]:::discarded
    _5050_829_374_3757_9082_13_358_1079[" am"] -->|Output beam 0| _out_0["My name is John Smith. I am"]:::output
    _5050_829_374_3757_13_358_1079_264[" a"] -->|Output beam 1| _out_1["My name is John. I am a"]:::output
    _5050_829_374_3757_13_358_1079_220[" "] -->|Output beam 2| _out_2["My name is John. I am "]:::output
    _5050_829_374_8364_13_358_1079_264[" a"] -->|Output beam 3| _out_3["My name is Tom. I am a"]:::output
    _5050_829_374_7607_13_358_1079_264[" a"] -->|Output beam 4| _out_4["My name is Jack. I am a"]:::output

In [296]:
from IPython.display import Markdown, display

def show_beam_search(root_beam: BeamNode, out_beams: list[BeamNode]):
    """
    Visualize the beam search tree using mermaid.js mindmap syntax.

    Displays the rendered mindmap and returns the raw mermaid source.
    """
    # start the mermaid source with the prompt as the root node
    mermaid_text = f"""
mindmap
  root(({root_beam.text()}))
"""

    def process_beam(beam: BeamNode, indent_level: int):
        nonlocal mermaid_text
        indent = "  " * indent_level  # mindmap nesting is expressed with 2-space indentation

        for child in beam.children:
            # strip characters that would break the mermaid node syntax
            clean_text = child.my_text().replace('"', '').replace('[', '').replace(']', '').replace('\n', '\\n') or " "
            mermaid_text += f'{indent}  + node["{clean_text}"]\n'
            process_beam(child, indent_level + 1)

        # leaves that survived the final pruning step are the returned output beams
        if beam.chosen and len(beam.children) == 0:
            clean_text = beam.text().replace('"', '').replace('[', '').replace(']', '').replace('\n', '\\n') or " "
            mermaid_text += f'{indent}  + output["OUTPUT: {clean_text}"]\n'

    # walk the tree starting from the root's children
    process_beam(root_beam, 1)

    # render the mindmap inside a mermaid fenced block
    display(Markdown(f'''
```mermaid
{mermaid_text}
```
    '''))

    return mermaid_text

print(show_beam_search(root_beam, beams))
mindmap
  root((My name is))
    + node[" John"]
      + node["."]
        + node[" I"]
          + node[" am"]
            + node[" a"]
              + output["OUTPUT: My name is John. I am a"]
            + node[" "]
              + output["OUTPUT: My name is John. I am "]
            + node[" an"]
            + node[" from"]
            + node[" in"]
          + node["'m"]
            + node[" a"]
            + node[" "]
            + node[" from"]
            + node[" an"]
            + node[" in"]
          + node[" have"]
          + node["’m"]
          + node[" live"]
        + node[" My"]
        + node[" This"]
        + node[" Today"]
        + node[" It"]
      + node[" Smith"]
        + node["."]
          + node[" I"]
            + node[" am"]
              + output["OUTPUT: My name is John Smith. I am"]
            + node["'m"]
            + node[" have"]
            + node["’m"]
            + node[" was"]
          + node[" My"]
          + node[" This"]
          + node[" Please"]
          + node[" The"]
        + node[" and"]
        + node[","]
        + node["."]
        + node[","]
      + node[" and"]
      + node["."]
      + node[","]
    + node[" Tom"]
      + node["."]
        + node[" I"]
          + node[" am"]
            + node[" a"]
              + output["OUTPUT: My name is Tom. I am a"]
            + node[" "]
            + node[" an"]
            + node[" from"]
            + node[" in"]
          + node["'m"]
          + node[" have"]
          + node[" live"]
          + node["’m"]
        + node[" My"]
        + node[" This"]
        + node["("]
        + node[" ("]
      + node["."]
      + node[" and"]
      + node[" Smith"]
      + node[","]
    + node[" Mary"]
      + node["."]
        + node[" I"]
          + node[" am"]
          + node["'m"]
          + node[" have"]
          + node[" live"]
          + node["’m"]
        + node[" My"]
        + node[" This"]
        + node[" Here"]
        + node["("]
      + node[" Smith"]
      + node[" and"]
      + node["."]
      + node[","]
    + node[" David"]
      + node["."]
      + node[" and"]
      + node[","]
      + node[" Smith"]
      + node["."]
    + node[" Jack"]
      + node["."]
        + node[" I"]
          + node[" am"]
            + node[" a"]
              + output["OUTPUT: My name is Jack. I am a"]
            + node[" "]
            + node[" from"]
            + node[" an"]
            + node[" twelve"]
          + node["'m"]
          + node[" have"]
          + node[" live"]
          + node["’m"]
        + node[" My"]
        + node[" This"]
        + node[" Jack"]
        + node[" Today"]
      + node[" and"]
      + node["."]
      + node[" Smith"]
      + node[","]

In [298]:
inputs = tokenizer("My name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=5, num_beams=5, num_return_sequences=5)

for output in outputs:
    print(tokenizer.decode(output, skip_special_tokens=True))
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
My name is John Smith. I am
My name is John. I am a
My name is John. I am 
My name is Tom. I am a
My name is Jack. I am a
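In the beam tree above, each edge carries two numbers: the probability of the individual token and the cumulative beam score, which is the product of all token probabilities along the path (for example, 0.0233974829 × 0.2794537842 ≈ 0.0065385150 for "My name is Tom."). A minimal pure-Python sketch of this bookkeeping, using a hypothetical hand-written next-token table in place of a real model:

```python
import math

# hypothetical toy next-token distributions, standing in for model logits
TOY_LM = {
    (): {"John": 0.6, "Tom": 0.4},
    ("John",): {"Smith": 0.7, ".": 0.3},
    ("Tom",): {"Smith": 0.2, ".": 0.8},
}

def beam_search(num_beams: int, steps: int):
    # each beam is (tokens, cumulative log-probability)
    beams = [((), 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, logp in beams:
            for tok, p in TOY_LM[tokens].items():
                # accumulating log-probs is equivalent to multiplying probabilities
                candidates.append((tokens + (tok,), logp + math.log(p)))
        # keep only the num_beams highest-scoring candidates
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:num_beams]
    return beams

for tokens, logp in beam_search(num_beams=2, steps=2):
    print(tokens, round(math.exp(logp), 4))
# → ('John', 'Smith') 0.42
# → ('Tom', '.') 0.32
```

Note that `transformers` additionally applies a length penalty when ranking finished beams (the `length_penalty` argument, default 1.0), dividing the summed log-probability by the sequence length raised to that power; the toy sketch above omits this.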
In [ ]: